{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "abe9913d",
   "metadata": {
    "id": "1a0f93c6"
   },
   "outputs": [],
   "source": [
    "BRANCH = 'main'\n",
    "\n",
    "\"\"\"\n",
    "You can run this notebook either locally (if you have all the dependencies and a GPU) or on Google Colab.\n",
    "\n",
    "Instructions for setting up Colab are as follows:\n",
    "1. Open a new Python 3 notebook.\n",
    "2. Import this notebook from GitHub (File -> Upload Notebook -> \"GITHUB\" tab -> copy/paste GitHub URL)\n",
    "3. Connect to an instance with a GPU (Runtime -> Change runtime type -> select \"GPU\" for hardware accelerator)\n",
    "4. Run this cell to set up dependencies.\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd26974d",
   "metadata": {
    "id": "ffdfe626"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "# either provide a path to local NeMo repository with NeMo already installed or git clone\n",
    "\n",
    "# option #1: local path to NeMo repo with NeMo already installed\n",
    "NEMO_DIR_PATH = os.path.dirname(os.path.dirname(os.path.abspath('')))\n",
    "is_colab = False\n",
    "\n",
    "# option #2: download NeMo repo\n",
    "if 'google.colab' in str(get_ipython()) or not os.path.exists(os.path.join(NEMO_DIR_PATH, \"nemo\")):\n",
    "    ## Install dependencies\n",
    "    !apt-get install -y sox libsndfile1 ffmpeg\n",
    "\n",
    "    !git clone -b $BRANCH https://github.com/NVIDIA/NeMo\n",
    "    %cd NeMo\n",
    "    !python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[all]\n",
    "    NEMO_DIR_PATH = os.path.abspath('')\n",
    "    is_colab = True\n",
    "\n",
    "import sys\n",
    "sys.path.insert(0, NEMO_DIR_PATH)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b3f35d50",
   "metadata": {
    "id": "bcc3e593"
   },
   "source": [
    "# 1. Introduction to ASR confidence estimation\n",
    "Confidence estimation is a crucial yet sometimes overlooked aspect of automatic speech recognition (ASR) systems. Confidence estimation for ASR is the process of estimating how reliable the output of an ASR system is. For an output transcription, confidence estimation answers the question \"How accurate is this transcription?\" or \"How likely is this transcription to be correct?\"\n",
    "\n",
    "A confidence score is the result of confidence estimation. It lies in the range from 0 to 1, where zero signals that the confidence estimator is completely unsure and one indicates that the estimator is fully confident in the output. Confidence scores are often used to guide downstream processing in ASR applications. For example, in a voice dictation application, a low confidence score could trigger the system to ask the user to repeat the input or to suggest alternative transcriptions.\n",
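    "\n",
    "As a minimal sketch of such downstream use (plain Python; the helper name `words_to_confirm` and the default threshold of 0.5 are hypothetical and not part of NeMo), the words whose confidence falls below a threshold can be collected so that the application can ask the user to confirm them:\n",
    "\n",
    "```python\n",
    "from typing import List, Tuple\n",
    "\n",
    "\n",
    "def words_to_confirm(words: List[str], scores: List[float], threshold: float = 0.5) -> List[Tuple[int, str]]:\n",
    "    # Return (position, word) pairs whose confidence score is below the threshold\n",
    "    assert len(words) == len(scores), 'expected one confidence score per word'\n",
    "    return [(i, w) for i, (w, s) in enumerate(zip(words, scores)) if s < threshold]\n",
    "\n",
    "\n",
    "print(words_to_confirm(['turn', 'on', 'the', 'lights'], [0.97, 0.91, 0.88, 0.31]))  # [(3, 'lights')]\n",
    "```\n",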
    "\n",
    "There are several approaches to confidence estimation in ASR, including:\n",
    "\n",
    "1. Acoustic modeling-based methods: These methods use the acoustic model scores to estimate the confidence score. The acoustic model represents the relationship between the acoustic signal and the corresponding linguistic units, and the score reflects the similarity between the observed signal and the predicted model output. Here, the confidence estimator can be either the ASR model itself (non-trainable methods) or a trainable external estimator that takes acoustic features or output probabilities and predicts confidence scores.\n",
    "\n",
    "2. Language modeling-based methods: These methods use the language model scores to estimate the confidence score. The language model represents the probability distribution over word sequences, and the score reflects the likelihood of the transcription given the language model.\n",
    "\n",
    "3. Combination methods: These methods combine the scores from both the acoustic and language models to estimate the confidence score. This approach can leverage the strengths of both models to achieve more accurate confidence scores.\n",
    "\n",
    "In this introductory tutorial we will cover only the non-trainable acoustic-based methods."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "34e356bf",
   "metadata": {
    "id": "59100fb9"
   },
   "source": [
    "## 1.1. Optional resources\n",
    "This tutorial is self-contained, but if you want to dive deeper into the topic, you can check out these resources:\n",
    "* Paper behind this tutorial: https://arxiv.org/abs/2212.08703\n",
    "* Supplementary blog post on how and why the confidence estimation methods of this tutorial were developed: https://developer.nvidia.com/blog/entropy-based-methods-for-word-level-asr-confidence-estimation/"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9739cb35",
   "metadata": {
    "id": "cd7226c5"
   },
   "source": [
    "# 2. Data Download\n",
    "First, let's download audio and text data. Here we will use LibriSpeech *test-other*."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46b2861b",
   "metadata": {
    "id": "fd542e62"
   },
   "outputs": [],
   "source": [
    "## create a data directory and download LibriSpeech test-other\n",
    "WORK_DIR = 'WORK_DIR'\n",
    "DATA_DIR = WORK_DIR + '/DATA'\n",
    "os.makedirs(DATA_DIR, exist_ok=True)\n",
    "\n",
    "print('downloading audio data...')\n",
    "!python $NEMO_DIR_PATH/scripts/dataset_processing/get_librispeech_data.py --data_root=$DATA_DIR --data_set=test_other\n",
    "!rm $DATA_DIR/test_other.tar.gz"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8ba5ad12",
   "metadata": {
    "id": "383eee71"
   },
   "source": [
    "# 3. Confidence estimation example\n",
    "Let's see how confidence scores can be obtained with NeMo models."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a95697fe",
   "metadata": {
    "id": "7c7c0170"
   },
   "source": [
    "## 3.1. Helper functions\n",
    "The following functions pretty-print confidence scores for word-level ASR hypotheses."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0bd12b7b",
   "metadata": {
    "id": "20cf0b38"
   },
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "from termcolor import colored\n",
    "from typing import List, Optional, Tuple, Union\n",
    "\n",
    "from IPython.display import Audio, HTML, Image, display\n",
    "import numpy as np\n",
    "import texterrors\n",
    "\n",
    "def get_detailed_wer_labels(ref: List[str], hyp: List[str], return_eps_padded_hyp: bool = False):\n",
    "    \"\"\"Get detailed WER labels, aligning reference with hypothesis.\n",
    "    \n",
    "    Possible WER labels:\n",
    "        - 'C' for Correct,\n",
    "        - 'I' for Insertion,\n",
    "        - 'D' for Deletion,\n",
    "        - 'S' for Substitution.\n",
    "\n",
    "    Returns:\n",
    "        WER labels list.\n",
    "        [Optional] Epsilon-padded hypothesis if return_eps_padded_hyp is set to True.\n",
    "    \"\"\"\n",
    "\n",
    "    # Align reference and hypothesis using \"<eps>\"\n",
    "    aligned_ref, aligned_hyp = texterrors.align_texts(ref, hyp, False)[:-1]\n",
    "\n",
    "    # Determine labels\n",
    "    labels = []\n",
    "    for r, h in zip(aligned_ref, aligned_hyp):\n",
    "        if r == h:\n",
    "            labels.append(\"C\")\n",
    "        elif r == \"<eps>\":\n",
    "            labels.append(\"I\")\n",
    "        elif h == \"<eps>\":\n",
    "            labels.append(\"D\")\n",
    "        else:\n",
    "            labels.append(\"S\")\n",
    "\n",
    "    return (labels, aligned_hyp) if return_eps_padded_hyp else labels\n",
    "\n",
    "\n",
    "def fill_confidence_deletions(confidence_scores: List[float], labels: List[str], fill_value: float = 0.0):\n",
    "    \"\"\"Fill confidence scores list with the provided value for deletions.\n",
    "    Assumes that we have no natural confidence scores for deletions.\n",
    "    \n",
    "    Returns:\n",
    "        Confidence scores list with deletion scores.\n",
    "    \"\"\"\n",
    "\n",
    "    assert len(confidence_scores) <= len(labels)\n",
    "\n",
    "    # If the lengths of confidence_scores and labels are equal, then we assume that there are no deletions\n",
    "    if len(confidence_scores) == len(labels):\n",
    "        return confidence_scores\n",
    "\n",
    "    # Insert fill_value into confidence_scores where label == \"D\"\n",
    "    new_confidence_scores = []\n",
    "    score_index = 0\n",
    "    for label in labels:\n",
    "        if label == \"D\":\n",
    "            new_confidence_scores.append(fill_value)\n",
    "        else:\n",
    "            new_confidence_scores.append(confidence_scores[score_index])\n",
    "            score_index += 1\n",
    "    return new_confidence_scores\n",
    "\n",
    "\n",
    "def pretty_pad_word_labels(labels: List[str], words: List[str]):\n",
    "    \"\"\"Pad word labels with dash for pretty printing.\n",
    "    Expects labels and words to have the same length.\n",
    "    \n",
    "    Returns:\n",
    "        Padded labels list.\n",
    "    \"\"\"\n",
    "    \n",
    "    # Check that words (including <eps> placeholders) and labels have the same length\n",
    "    assert len(words) == len(labels)\n",
    "\n",
    "    # Pad the labels with dashes to align them with the words\n",
    "    padded_labels = []\n",
    "    for word, label in zip(words, labels):\n",
    "        label_len = len(word)\n",
    "        left_padding = (label_len - 1) // 2\n",
    "        right_padding = label_len - left_padding - 1\n",
    "        padded_label = \"-\" * left_padding + label + \"-\" * right_padding\n",
    "        padded_labels.append(padded_label)\n",
    "\n",
    "    return padded_labels\n",
    "\n",
    "\n",
    "def _html_paint_word_grey(word: str, shade: str):\n",
    "    if shade == \"black\":\n",
    "        color = \"0,0,0\"\n",
    "    elif shade == \"grey\":\n",
    "        color = \"150,150,150\"\n",
    "    elif shade == \"light_grey\":\n",
    "        color = \"200,200,200\"\n",
    "    else:\n",
    "        raise ValueError(\n",
    "            f\"`shade` has to be one of the following: `black`, `grey`, `light_grey`. Provided: `{shade}`\"\n",
    "        )\n",
    "    return f'<mark style=\"color:rgb({color});background-color:rgb(255,255,255);\">{word}</mark>'\n",
    "\n",
    "\n",
    "def pretty_print_transcript_with_confidence(\n",
    "    transcript: str,\n",
    "    confidence_scores: List[float],\n",
    "    threshold: float,\n",
    "    reference: Optional[str] = None,\n",
    "    terminal_width: int = 120,\n",
    "    html: bool = False,\n",
    "):\n",
    "    if html:\n",
    "        shade_if_low_confidence = lambda x, y: _html_paint_word_grey(x, 'light_grey' if y < threshold else 'black')\n",
    "        new_line_mark = \"<br>\"\n",
    "        pretty_print = lambda x: display(HTML(\"<code>\" + new_line_mark.join(x) + \"</code>\"))\n",
    "    else:\n",
    "        shade_if_low_confidence = lambda x, y: colored(x, 'light_grey') if y < threshold else x\n",
    "        new_line_mark = \"\\n\"\n",
    "        pretty_print = lambda x: print(new_line_mark.join(x))\n",
    "    with_labels = reference is not None\n",
    "    transcript_list = transcript.split()\n",
    "    output_lines = []\n",
    "    if with_labels:\n",
    "        reference_list = reference.split()\n",
    "        labels, eps_padded_hyp = get_detailed_wer_labels(reference_list, transcript_list, True)\n",
    "        padded_labels = pretty_pad_word_labels(labels, eps_padded_hyp)\n",
    "        current_line_len = 0\n",
    "        current_word_line = \"\"\n",
    "        current_label_line = \"\"\n",
    "        for word, label, padded_label, score in zip(\n",
    "            eps_padded_hyp, labels, padded_labels, fill_confidence_deletions(confidence_scores, labels)\n",
    "        ):\n",
    "            word_len = len(word)\n",
    "            # escape angle brackets for <eps>\n",
    "            if html and word == \"<eps>\":\n",
    "                word = \"&lt;eps&gt;\"\n",
    "            if current_line_len + word_len + 1 <= terminal_width:\n",
    "                if current_line_len > 0:\n",
    "                    current_line_len += 1\n",
    "                    current_word_line += \" \"\n",
    "                    current_label_line += \"-\"\n",
    "                current_line_len += word_len\n",
    "                current_word_line += shade_if_low_confidence(word, score)\n",
    "                current_label_line += padded_label\n",
    "            else:\n",
    "                output_lines.append(current_word_line + new_line_mark + current_label_line)\n",
    "                current_line_len = word_len\n",
    "                current_word_line = shade_if_low_confidence(word, score)\n",
    "                current_label_line = padded_label\n",
    "        if current_word_line:\n",
    "            output_lines.append(current_word_line + new_line_mark + current_label_line)\n",
    "    else:\n",
    "        current_line_len = 0\n",
    "        current_word_line = \"\"\n",
    "        for word, score in zip(transcript_list, confidence_scores):\n",
    "            word_len = len(word)\n",
    "            # escape angle brackets for <eps>\n",
    "            if html and word == \"<eps>\":\n",
    "                word = \"&lt;eps&gt;\"\n",
    "            if current_line_len + word_len + 1 <= terminal_width:\n",
    "                if current_line_len > 0:\n",
    "                    current_line_len += 1\n",
    "                    current_word_line += \" \"\n",
    "                current_line_len += word_len\n",
    "                current_word_line += shade_if_low_confidence(word, score)\n",
    "            else:\n",
    "                output_lines.append(current_word_line)\n",
    "                current_line_len = word_len\n",
    "                current_word_line = shade_if_low_confidence(word, score)\n",
    "        if current_word_line:\n",
    "            output_lines.append(current_word_line)\n",
    "\n",
    "    pretty_print(output_lines)"
   ]
  },
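  {
   "cell_type": "markdown",
   "id": "sketch-wer-alignment",
   "metadata": {},
   "source": [
    "The helper `get_detailed_wer_labels` above delegates the alignment to `texterrors.align_texts`. To illustrate what such an alignment does, here is a minimal pure-Python sketch based on the standard Levenshtein (edit distance) dynamic program with a backtrace. The name `align_wer_labels` is hypothetical, and when several alignments have the same cost, its tie-breaking may differ from that of `texterrors`.\n",
    "\n",
    "```python\n",
    "from typing import List, Tuple\n",
    "\n",
    "\n",
    "def align_wer_labels(ref: List[str], hyp: List[str], eps: str = '<eps>') -> Tuple[List[str], List[str]]:\n",
    "    # Levenshtein DP table: cost[i][j] = edit distance between ref[:i] and hyp[:j]\n",
    "    n, m = len(ref), len(hyp)\n",
    "    cost = [[0] * (m + 1) for _ in range(n + 1)]\n",
    "    for i in range(1, n + 1):\n",
    "        cost[i][0] = i\n",
    "    for j in range(1, m + 1):\n",
    "        cost[0][j] = j\n",
    "    for i in range(1, n + 1):\n",
    "        for j in range(1, m + 1):\n",
    "            sub = cost[i - 1][j - 1] + (ref[i - 1] != hyp[j - 1])\n",
    "            cost[i][j] = min(sub, cost[i - 1][j] + 1, cost[i][j - 1] + 1)\n",
    "    # Backtrace from the bottom-right corner, preferring match/substitution, then deletion\n",
    "    labels, padded_hyp = [], []\n",
    "    i, j = n, m\n",
    "    while i > 0 or j > 0:\n",
    "        if i > 0 and j > 0 and cost[i][j] == cost[i - 1][j - 1] + (ref[i - 1] != hyp[j - 1]):\n",
    "            labels.append('C' if ref[i - 1] == hyp[j - 1] else 'S')\n",
    "            padded_hyp.append(hyp[j - 1])\n",
    "            i, j = i - 1, j - 1\n",
    "        elif i > 0 and cost[i][j] == cost[i - 1][j] + 1:\n",
    "            labels.append('D')  # reference word missing from the hypothesis\n",
    "            padded_hyp.append(eps)\n",
    "            i -= 1\n",
    "        else:\n",
    "            labels.append('I')  # extra word in the hypothesis\n",
    "            padded_hyp.append(hyp[j - 1])\n",
    "            j -= 1\n",
    "    return labels[::-1], padded_hyp[::-1]\n",
    "\n",
    "\n",
    "labels, padded = align_wer_labels('the cat sat on mat'.split(), 'the cat sit on the mat'.split())\n",
    "print(labels)  # ['C', 'C', 'S', 'C', 'I', 'C']\n",
    "print(padded)  # ['the', 'cat', 'sit', 'on', 'the', 'mat']\n",
    "```"
   ]
  },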
  {
   "cell_type": "markdown",
   "id": "ed997bfd",
   "metadata": {
    "id": "dec57a27"
   },
   "source": [
    "## 3.2. Data and model loading\n",
    "This tutorial uses CTC and RNN-T Conformer models trained on LibriSpeech.\n",
    "\n",
    "You can try other pre-trained models as well."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70c1a27a",
   "metadata": {
    "id": "b66c60a3"
   },
   "outputs": [],
   "source": [
    "from dataclasses import dataclass\n",
    "from omegaconf import DictConfig, OmegaConf\n",
    "\n",
    "from nemo.collections.asr.models import ASRModel\n",
    "\n",
    "def load_model(name: str):\n",
    "    \"\"\"Load a pre-trained model.\n",
    "\n",
    "    Args:\n",
    "        name: Pre-trained model name.\n",
    "            Reserved names:\n",
    "            - 'ctc' for 'stt_en_conformer_ctc_large_ls'\n",
    "            - 'rnnt' for 'stt_en_conformer_transducer_large_ls'\n",
    "\n",
    "    Returns:\n",
    "        A model loaded onto the GPU and set to eval mode.\n",
    "    \"\"\"\n",
    "    if name == \"ctc\":\n",
    "        name = \"stt_en_conformer_ctc_large_ls\"\n",
    "    elif name == \"rnnt\":\n",
    "        name = \"stt_en_conformer_transducer_large_ls\"\n",
    "\n",
    "    model = ASRModel.from_pretrained(model_name=name, map_location=\"cuda:0\")\n",
    "    model.eval()\n",
    "\n",
    "    return model\n",
    "\n",
    "@dataclass\n",
    "class TestSet:\n",
    "    filepaths: List[str]\n",
    "    reference_texts: List[str]\n",
    "    durations: List[float]\n",
    "\n",
    "def load_data(manifest_path: str):\n",
    "    filepaths = []\n",
    "    reference_texts = []\n",
    "    durations = []\n",
    "    with open(manifest_path, \"r\") as f:\n",
    "        for line in f:\n",
    "            item = json.loads(line)\n",
    "            audio_file = item[\"audio_filepath\"]\n",
    "            filepaths.append(str(audio_file))\n",
    "            text = item[\"text\"]\n",
    "            reference_texts.append(text)\n",
    "            durations.append(float(item[\"duration\"]))\n",
    "    return TestSet(filepaths, reference_texts, durations)\n",
    "\n",
    "TEST_MANIFESTS = {\n",
    "    \"test_other\": DATA_DIR + \"/test_other.json\",\n",
    "}\n",
    "\n",
    "\n",
    "# Load data\n",
    "test_sets = {manifest: load_data(path) for manifest, path in TEST_MANIFESTS.items()}\n",
    "\n",
    "# Load model\n",
    "is_rnnt = False\n",
    "# is_rnnt = True\n",
    "\n",
    "model = load_model(\"rnnt\" if is_rnnt else \"ctc\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9c5db700",
   "metadata": {
    "id": "88c3d7ee"
   },
   "source": [
    "## 3.3. Setting up confidence estimation\n",
    "To set up confidence estimation for NeMo ASR models, you need to:\n",
    "1. Initialize _ConfidenceConfig_\n",
    "2. Put the created _ConfidenceConfig_ into the model decoding config.\n",
    "\n",
    "The following cell shows an example of initializing a _ConfidenceConfig_ and updating the model's decoding config.\n",
    "\n",
    "It also prints the allowed values of the _ConfidenceConfig_ parameters.\n",
    "\n",
    "Note that only `strategy=\"greedy\"` (or `greedy_batch` for RNN-T) supports computing confidence scores."
   ]
  },
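  {
   "cell_type": "markdown",
   "id": "sketch-frame-confidence",
   "metadata": {},
   "source": [
    "To make the `max_prob` and `entropy` options more concrete, the sketch below computes a per-frame confidence from a single frame's probability distribution over a vocabulary of size V, using linear normalization (in the spirit of `entropy_norm=\"lin\"`). It is a simplified illustration under stated assumptions, not NeMo's implementation: the function names are hypothetical, Gibbs entropy is shown only for the case `alpha = 1`, and the library's options also cover Rényi entropy and `exp` normalization.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def max_prob_confidence(probs: np.ndarray) -> float:\n",
    "    # Map the maximum probability linearly from [1/V, 1] onto [0, 1]\n",
    "    v = probs.shape[-1]\n",
    "    return float((v * probs.max() - 1.0) / (v - 1.0))\n",
    "\n",
    "\n",
    "def gibbs_entropy_confidence(probs: np.ndarray) -> float:\n",
    "    # 1 - H(p) / log(V): 1 for a one-hot distribution, 0 for a uniform one\n",
    "    v = probs.shape[-1]\n",
    "    safe = np.clip(probs, 1e-12, 1.0)  # avoid log(0); zero-probability terms still contribute 0\n",
    "    entropy = -np.sum(probs * np.log(safe))\n",
    "    return float(1.0 - entropy / np.log(v))\n",
    "\n",
    "\n",
    "def tsallis_entropy_confidence(probs: np.ndarray, alpha: float = 0.5) -> float:\n",
    "    # Tsallis entropy S = (1 - sum(p ** alpha)) / (alpha - 1), normalized by its maximum (uniform p)\n",
    "    assert alpha > 0.0 and alpha != 1.0, 'alpha = 1 corresponds to the Gibbs entropy above'\n",
    "    v = probs.shape[-1]\n",
    "    entropy = (1.0 - np.sum(probs ** alpha)) / (alpha - 1.0)\n",
    "    max_entropy = (1.0 - v ** (1.0 - alpha)) / (alpha - 1.0)\n",
    "    return float(1.0 - entropy / max_entropy)\n",
    "\n",
    "\n",
    "frame = np.array([0.7, 0.1, 0.1, 0.05, 0.05])  # toy softmax output for one frame, V = 5\n",
    "print(max_prob_confidence(frame), gibbs_entropy_confidence(frame), tsallis_entropy_confidence(frame))\n",
    "```\n",
    "\n",
    "All three map a one-hot distribution to 1 and a uniform distribution to 0; they differ in how quickly the score drops as probability mass spreads over competing tokens."
   ]
  },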
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d3e8c11",
   "metadata": {
    "id": "078005f1"
   },
   "outputs": [],
   "source": [
    "from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTDecodingConfig\n",
    "from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecodingConfig\n",
    "from nemo.collections.asr.parts.utils.asr_confidence_utils import (\n",
    "    ConfidenceConfig,\n",
    "    ConfidenceConstants,\n",
    "    ConfidenceMethodConfig,\n",
    "    ConfidenceMethodConstants,\n",
    ")\n",
    "from nemo.collections.asr.parts.utils.asr_confidence_benchmarking_utils import (\n",
    "    apply_confidence_parameters,\n",
    "    get_correct_marks,\n",
    "    get_token_targets_with_confidence,\n",
    "    get_word_targets_with_confidence,\n",
    ")\n",
    "\n",
    "\n",
    "# List allowed options for ConfidenceMethodConfig and ConfidenceConfig\n",
    "print(f\"Allowed options for ConfidenceMethodConfig: {ConfidenceMethodConstants.print()}\\n\")\n",
    "print(f\"Allowed options for ConfidenceConfig: {ConfidenceConstants.print()}\\n\")\n",
    "\n",
    "# Initialize ConfidenceConfig and ConfidenceMethodConfig\n",
    "confidence_cfg = ConfidenceConfig(\n",
    "    preserve_frame_confidence=True, # Internally set to true if preserve_token_confidence == True\n",
    "    # or preserve_word_confidence == True\n",
    "    preserve_token_confidence=True, # Internally set to true if preserve_word_confidence == True\n",
    "    preserve_word_confidence=True,\n",
    "    aggregation=\"prod\", # How to aggregate frame scores to token scores and token scores to word scores\n",
    "    exclude_blank=False, # If true, only non-blank emissions contribute to confidence scores\n",
    "    tdt_include_duration=False, # If true, calculate duration confidence for the TDT models\n",
    "    method_cfg=ConfidenceMethodConfig( # Config for per-frame scores calculation (before aggregation)\n",
    "        name=\"max_prob\", # Or \"entropy\" (default), which usually works better\n",
    "        entropy_type=\"gibbs\", # Used only for name == \"entropy\". Recommended: \"tsallis\" (default) or \"renyi\"\n",
    "        alpha=0.5, # Low values (<1) increase sensitivity, high values decrease sensitivity\n",
    "        entropy_norm=\"lin\" # How to normalize (map to [0,1]) entropy. Default: \"exp\"\n",
    "    )\n",
    ")\n",
    "\n",
    "# Alternatively, look at ConfidenceConfig's docstring\n",
    "print(f\"More info on ConfidenceConfig here:\\n{ConfidenceConfig().__doc__}\\n\")\n",
    "\n",
    "# Put the created ConfidenceConfig into the model decoding config via .change_decoding_strategy()\n",
    "model.change_decoding_strategy(\n",
    "    RNNTDecodingConfig(fused_batch_size=-1, strategy=\"greedy_batch\", confidence_cfg=confidence_cfg)\n",
    "    if is_rnnt\n",
    "    else CTCDecodingConfig(confidence_cfg=confidence_cfg)\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "04581687",
   "metadata": {
    "id": "efe0baea"
   },
   "source": [
    "## 3.4. Decode test set and get transcriptions with confidence scores\n",
    "Let's transcribe LibriSpeech _test-other_ and inspect the confidence scores attached to the hypotheses."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5f92257",
   "metadata": {
    "id": "ccd8d0de"
   },
   "outputs": [],
   "source": [
    "current_test_set = test_sets[\"test_other\"]\n",
    "transcriptions = model.transcribe(audio=current_test_set.filepaths, batch_size=16, return_hypotheses=True, num_workers=4)\n",
    "if is_rnnt:\n",
    "    transcriptions = transcriptions[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ca282352",
   "metadata": {
    "id": "0500514e"
   },
   "source": [
    "Each transcribed hypothesis can contain `frame_confidence`, as well as `token_confidence` (aggregated from the frame scores) and `word_confidence` (aggregated from the token scores)."
   ]
  },
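  {
   "cell_type": "markdown",
   "id": "sketch-word-aggregation",
   "metadata": {},
   "source": [
    "The token-to-word step of this aggregation can be sketched as follows, assuming SentencePiece-style tokens in which a leading `▁` marks the start of a new word. Only `prod` appears in the config above; the other names in the hypothetical `AGGREGATIONS` table are assumptions (check `ConfidenceConstants.print()` for the values your NeMo version accepts), and NeMo's own grouping logic depends on the tokenizer.\n",
    "\n",
    "```python\n",
    "import math\n",
    "from typing import Callable, Dict, List\n",
    "\n",
    "# Hypothetical aggregation table; names assumed to mirror the `aggregation` option\n",
    "AGGREGATIONS: Dict[str, Callable[[List[float]], float]] = {\n",
    "    'mean': lambda xs: sum(xs) / len(xs),\n",
    "    'min': min,\n",
    "    'max': max,\n",
    "    'prod': math.prod,\n",
    "}\n",
    "\n",
    "\n",
    "def word_confidence_from_tokens(tokens: List[str], token_scores: List[float], aggregation: str = 'prod') -> List[float]:\n",
    "    # Group SentencePiece-style tokens into words: a leading '▁' starts a new word\n",
    "    assert len(tokens) == len(token_scores)\n",
    "    aggregate = AGGREGATIONS[aggregation]\n",
    "    groups: List[List[float]] = []\n",
    "    for token, score in zip(tokens, token_scores):\n",
    "        if token.startswith('▁') or not groups:\n",
    "            groups.append([score])\n",
    "        else:\n",
    "            groups[-1].append(score)\n",
    "    return [aggregate(g) for g in groups]\n",
    "\n",
    "\n",
    "tokens = ['▁he', 'llo', '▁wor', 'ld']\n",
    "scores = [0.9, 0.8, 0.5, 1.0]\n",
    "print(word_confidence_from_tokens(tokens, scores, 'prod'))  # approximately [0.72, 0.5]\n",
    "print(word_confidence_from_tokens(tokens, scores, 'min'))  # [0.8, 0.5]\n",
    "```\n",
    "\n",
    "Note that `prod` penalizes long words with many sub-word tokens, whereas `mean` and `min` do not depend on the number of tokens in the same way."
   ]
  },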
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18663384",
   "metadata": {
    "id": "98035fd2"
   },
   "outputs": [],
   "source": [
    "def round_confidence(confidence_number, ndigits=3):\n",
    "    if isinstance(confidence_number, float):\n",
    "        return round(confidence_number, ndigits)\n",
    "    elif len(confidence_number.size()) == 0:  # torch.tensor with one element\n",
    "        return round(confidence_number.item(), ndigits)\n",
    "    elif len(confidence_number.size()) == 1:  # torch.tensor with a list of elements\n",
    "        return [round(c.item(), ndigits) for c in confidence_number]\n",
    "    else:\n",
    "        raise RuntimeError(f\"Unexpected confidence_number: `{confidence_number}`\")\n",
    "\n",
    "\n",
    "tran = transcriptions[0]\n",
    "print(\n",
    "    f\"\"\"    Recognized text: `{tran.text}`\\n\n",
    "    Word confidence: {[round_confidence(c) for c in tran.word_confidence]}\\n\n",
    "    Token confidence: {[round_confidence(c) for c in tran.token_confidence]}\\n\n",
    "    Frame confidence: {\n",
    "        [([round_confidence(cc) for cc in c] if is_rnnt else round_confidence(c)) for c in tran.frame_confidence]\n",
    "    }\"\"\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "783e9e2a",
   "metadata": {
    "id": "9613bfc1"
   },
   "source": [
    "Now let's draw the recognition results highlighted according to their confidence scores.\n",
    "\n",
    "There are four options: plain text and HTML with or without WER labels."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "642fe059",
   "metadata": {
    "id": "a83295ff"
   },
   "outputs": [],
   "source": [
    "from nemo.collections.asr.metrics.wer import word_error_rate, word_error_rate_detail, word_error_rate_per_utt\n",
    "\n",
    "def show_dataset_with_confidence(\n",
    "    indices,\n",
    "    transcriptions,\n",
    "    test_set,\n",
    "    threshold,\n",
    "    filepaths=None,\n",
    "    html_show=False,\n",
    "    min_dur_to_show=0.0,\n",
    "    utt_to_show=10\n",
    "):\n",
    "    utt_shown = 0\n",
    "    for i, _ in indices:\n",
    "        if utt_shown >= utt_to_show:\n",
    "            break\n",
    "        if test_set.durations[i] >= min_dur_to_show:\n",
    "            print(\"=\"*120)\n",
    "            hyp = transcriptions[i].text\n",
    "            scores = transcriptions[i].word_confidence\n",
    "            ref = test_set.reference_texts[i]\n",
    "            pretty_print_transcript_with_confidence(hyp, scores, threshold, ref, html=html_show)\n",
    "            if filepaths is not None:\n",
    "                display(Audio(filepaths[i]))\n",
    "            utt_shown += 1\n",
    "\n",
    "\n",
    "# you can play with these parameters\n",
    "threshold = 0.52\n",
    "# in Colab, use `html_show = True`, since non-HTML coloring is displayed incorrectly there\n",
    "html_show = is_colab\n",
    "min_dur_to_show = 4.0\n",
    "utt_to_show = 5\n",
    "\n",
    "wer_per_utt, avg_wer = word_error_rate_per_utt([h.text for h in transcriptions], current_test_set.reference_texts)\n",
    "sorted_wer_indices = sorted(enumerate(wer_per_utt), key=lambda x: x[1])[::-1]\n",
    "\n",
    "show_dataset_with_confidence(\n",
    "    indices=sorted_wer_indices,\n",
    "    transcriptions=transcriptions,\n",
    "    test_set=current_test_set,\n",
    "    threshold=threshold,\n",
    "    filepaths=current_test_set.filepaths,\n",
    "    html_show=html_show,\n",
    "    min_dur_to_show=min_dur_to_show,\n",
    "    utt_to_show=utt_to_show\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9468ad3e",
   "metadata": {
    "id": "dbfcb2da"
   },
   "source": [
    "## 3.5. Confidence metrics\n",
    "\n",
    "There are several metrics to evaluate the effectiveness of a confidence estimation method. Some of them treat confidence estimation as a binary classification task. Others measure how close the confidence scores of correct words are to $1.0$ and those of incorrect words are to $0.0$.\n",
    "\n",
    "Some of them are:\n",
    "1. Area Under the Receiver Operating Characteristic Curve ($\\mathrm{AUC}_\\mathrm{ROC}$): class separability metric.\n",
    "2. Area Under the Precision-Recall Curve ($\\mathrm{AUC}_\\mathrm{PR}$): how well the correct words are detected.\n",
    "3. Area Under the Negative Predictive Value vs. True Negative Rate Curve ($\\mathrm{AUC}_\\mathrm{NT}$): how well the incorrect words are detected ($\\mathrm{AUC}_\\mathrm{PR}$ in which errors are treated as positives).\n",
    "4. Normalized Cross Entropy ($\\mathrm{NCE}$): how close the confidence scores of correct predictions are to $1.0$ and those of incorrect predictions to $0.0$. It ranges from $-\\infty$ to $1.0$; negative scores indicate that the confidence method performs worse than setting every confidence score to $1-\\mathrm{WER}$. This metric is also known as Normalized Mutual Information.\n",
    "5. Expected Calibration Error ($\\mathrm{ECE}$): a weighted average over the absolute accuracy/confidence difference. It ranges from $0.0$ to $1.0$ with the best value $0.0$.\n",
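    "\n",
    "A minimal NumPy sketch of three of these metrics follows, assuming correct words are labeled 1 and incorrect words 0. The function names are hypothetical; NeMo's own versions live in `nemo.collections.asr.parts.utils.confidence_metrics` and may differ in details such as binning and numerical safeguards.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def auc_roc_sketch(y_true, y_score) -> float:\n",
    "    # Probability that a random correct word scores higher than a random incorrect one (ties count 0.5)\n",
    "    y_true, y_score = np.asarray(y_true, dtype=bool), np.asarray(y_score, dtype=float)\n",
    "    pos, neg = y_score[y_true], y_score[~y_true]\n",
    "    greater = (pos[:, None] > neg[None, :]).mean()\n",
    "    ties = (pos[:, None] == neg[None, :]).mean()\n",
    "    return float(greater + 0.5 * ties)\n",
    "\n",
    "\n",
    "def nce_sketch(y_true, y_score, eps: float = 1e-15) -> float:\n",
    "    # (H(p) - BCE) / H(p), where p is the fraction of correct words (roughly 1 - WER)\n",
    "    y_true = np.asarray(y_true, dtype=float)\n",
    "    y_score = np.clip(np.asarray(y_score, dtype=float), eps, 1.0 - eps)\n",
    "    p = y_true.mean()\n",
    "    assert 0.0 < p < 1.0, 'NCE is undefined if all words are correct or all are incorrect'\n",
    "    h_p = -(p * np.log(p) + (1.0 - p) * np.log(1.0 - p))\n",
    "    bce = -np.mean(y_true * np.log(y_score) + (1.0 - y_true) * np.log(1.0 - y_score))\n",
    "    return float((h_p - bce) / h_p)\n",
    "\n",
    "\n",
    "def ece_sketch(y_true, y_score, n_bins: int = 10) -> float:\n",
    "    # Bin the scores, then average |accuracy - mean confidence| weighted by the bin size\n",
    "    y_true, y_score = np.asarray(y_true, dtype=float), np.asarray(y_score, dtype=float)\n",
    "    bins = np.minimum((y_score * n_bins).astype(int), n_bins - 1)\n",
    "    total = 0.0\n",
    "    for b in np.unique(bins):\n",
    "        in_bin = bins == b\n",
    "        total += in_bin.mean() * abs(y_true[in_bin].mean() - y_score[in_bin].mean())\n",
    "    return float(total)\n",
    "\n",
    "\n",
    "toy_true = [1, 1, 0, 0]\n",
    "toy_score = [0.9, 0.8, 0.3, 0.1]\n",
    "print(auc_roc_sketch(toy_true, toy_score), nce_sketch(toy_true, toy_score), ece_sketch(toy_true, toy_score))\n",
    "```\n",
    "\n",
    "With a constant score equal to the fraction of correct words (here $0.5$), `nce_sketch` returns $0$, which is the reference point behind the negative-score interpretation above.\n",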
    "\n",
    "Metrics based on the Youden's curve (see https://en.wikipedia.org/wiki/Youden%27s_J_statistic) can also be considered. They are:\n",
    "1. Area Under the Youden's curve ($\\mathrm{AUC}_\\mathrm{YC}$): the rate of the effective threshold range (i.e. the adjustability or responsiveness). It ranges from $0.0$ to $1.0$ with the best value $0.5$.\n",
    "2. Maximum of the Youden's curve ($\\mathrm{MAX}_\\mathrm{YC}$): the optimal $\\mathrm{TNR}$ vs. $\\mathrm{FNR}$ tradeoff. Its unnormalized version can be used as a criterion for selecting the optimal confidence threshold $\\tau$. It ranges from $0.0$ to $1.0$ with the best value $1.0$.\n",
    "3. The standard deviation of the Youden's curve values ($\\mathrm{STD}_\\mathrm{YC}$): indicates that $\\mathrm{TNR}$ and $\\mathrm{FNR}$ increase at different rates (viz. $\\mathrm{TNR}$ grows faster) as $\\tau$ increases. It ranges from $0.0$ to $0.5$ with the best value around $0.25$.\n",
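    "\n",
    "A minimal NumPy sketch of the Youden's curve metrics, assuming that a word is rejected (flagged as incorrect) when its confidence is below the threshold $\\tau$ and that the curve is sampled on a uniform grid of thresholds in $[0, 1]$. The function name and grid size are hypothetical, and NeMo's `auc_yc` may compute the curve differently.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def youden_curve_sketch(y_true, y_score, n_thresholds: int = 101):\n",
    "    # For each threshold tau, words with score < tau are rejected (flagged as incorrect).\n",
    "    # J(tau) = TNR(tau) - FNR(tau): share of incorrect words rejected minus share of correct words rejected.\n",
    "    y_true, y_score = np.asarray(y_true, dtype=bool), np.asarray(y_score, dtype=float)\n",
    "    thresholds = np.linspace(0.0, 1.0, n_thresholds)\n",
    "    correct, incorrect = y_score[y_true], y_score[~y_true]\n",
    "    tnr = np.array([(incorrect < t).mean() for t in thresholds])\n",
    "    fnr = np.array([(correct < t).mean() for t in thresholds])\n",
    "    yc = tnr - fnr\n",
    "    # AUC_YC approximated by the mean over the uniform grid, plus MAX_YC and STD_YC\n",
    "    return float(yc.mean()), float(yc.max()), float(yc.std()), thresholds, yc\n",
    "\n",
    "\n",
    "auc_yc_value, max_yc_value, std_yc_value, taus, yc = youden_curve_sketch([1, 1, 0, 0], [0.9, 0.8, 0.3, 0.1])\n",
    "best_tau = taus[np.argmax(yc)]  # lowest tau that achieves the maximum\n",
    "print(auc_yc_value, max_yc_value, std_yc_value, best_tau)\n",
    "```\n",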
    "\n",
    "When selecting/tuning a confidence method, it is recommended to maximize $\\mathrm{AUC}_\\mathrm{ROC}$ first, as this is the main metric of confidence estimation quality. Then, for overconfident models, maximizing $\\mathrm{AUC}_\\mathrm{NT}$ should take precedence over $\\mathrm{AUC}_\\mathrm{PR}$. Finally, a trade-off between $\\mathrm{NCE}$/$\\mathrm{ECE}$ and the family of $\\mathrm{YC}$ metrics can be considered as a compromise between formal correctness and controllability.\n",
    "\n",
    "Let's see how well our confidence performs according to the metrics above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b0fa793",
   "metadata": {
    "id": "5d152775"
   },
   "outputs": [],
   "source": [
    "from nemo.collections.asr.parts.utils.confidence_metrics import (\n",
    "    auc_nt,\n",
    "    auc_pr,\n",
    "    auc_roc,\n",
    "    auc_yc,\n",
    "    ece,\n",
    "    nce,\n",
    "    save_confidence_hist,\n",
    "    save_custom_confidence_curve,\n",
    "    save_nt_curve,\n",
    "    save_pr_curve,\n",
    "    save_roc_curve,\n",
    ")\n",
    "\n",
    "\n",
    "targets_with_confidence = [get_word_targets_with_confidence(tran) for tran in transcriptions]\n",
    "correct_marks = [get_correct_marks(r.split(), h.words) for r, h in zip(current_test_set.reference_texts, transcriptions)]\n",
    "\n",
    "y_true, y_score = np.array(\n",
    "    [[f, p[1]] for cm, twc in zip(correct_marks, targets_with_confidence) for f, p in zip(cm, twc)]\n",
    ").T\n",
    "\n",
    "\n",
    "# output scheme: yc.mean(), yc.std(), yc.max() or yc.mean(), yc.std(), yc.max(), (thresholds, yc)\n",
    "result_yc = auc_yc(y_true, y_score, return_std_maximum=True, return_curve=True)\n",
    "# output scheme: ece or ece, (thresholds, ece_curve)\n",
    "results_ece = ece(y_true, y_score, return_curve=True)\n",
    "results = [\n",
    "    auc_roc(y_true, y_score),\n",
    "    auc_pr(y_true, y_score),\n",
    "    auc_nt(y_true, y_score),\n",
    "    nce(y_true, y_score),\n",
    "    results_ece[0],\n",
    "] + list(result_yc[:3])\n",
    "\n",
    "print(\n",
    "    f\"\"\"    AUC_ROC:\\t{results[0]:.5f}\n",
    "    AUC_PR:\\t{results[1]:.5f}\n",
    "    AUC_NT:\\t{results[2]:.5f}\n",
    "    NCE:\\t{results[3]:.5f}\n",
    "    ECE:\\t{results[4]:.5f}\n",
    "    AUC_YC:\\t{results[5]:.5f}\n",
    "    MAX_YC:\\t{results[7]:.5f}\n",
    "    STD_YC:\\t{results[6]:.5f}\n",
    "    \"\"\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0c3f6299",
   "metadata": {
    "id": "4159034d"
   },
   "source": [
    "The metrics for the maximum probability confidence method are not that great.\n",
    "\n",
    "Let's re-run and benchmark confidence estimation with the default confidence estimator."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c0e3a9f",
   "metadata": {
    "id": "d2e16f5f"
   },
   "outputs": [],
   "source": [
    "confidence_cfg = ConfidenceConfig(\n",
    "    preserve_word_confidence=True,\n",
    "    preserve_token_confidence=True,\n",
    ")\n",
    "\n",
    "model.change_decoding_strategy(\n",
    "    RNNTDecodingConfig(fused_batch_size=-1, strategy=\"greedy_batch\", confidence_cfg=confidence_cfg)\n",
    "    if is_rnnt\n",
    "    else CTCDecodingConfig(confidence_cfg=confidence_cfg)\n",
    ")\n",
    "\n",
    "transcriptions = model.transcribe(audio=current_test_set.filepaths, batch_size=16, return_hypotheses=True, num_workers=4)\n",
    "if is_rnnt:\n",
    "    transcriptions = transcriptions[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8f1cc77",
   "metadata": {
    "id": "6201ea4d"
   },
   "outputs": [],
   "source": [
    "targets_with_confidence = [get_word_targets_with_confidence(tran) for tran in transcriptions]\n",
    "correct_marks = [get_correct_marks(r.split(), h.words) for r, h in zip(current_test_set.reference_texts, transcriptions)]\n",
    "\n",
    "y_true, y_score = np.array(\n",
    "    [[f, p[1]] for cm, twc in zip(correct_marks, targets_with_confidence) for f, p in zip(cm, twc)]\n",
    ").T\n",
    "\n",
    "result_yc = auc_yc(y_true, y_score, return_std_maximum=True, return_curve=True)\n",
    "results_ece = ece(y_true, y_score, return_curve=True)\n",
    "results = [\n",
    "    auc_roc(y_true, y_score),\n",
    "    auc_pr(y_true, y_score),\n",
    "    auc_nt(y_true, y_score),\n",
    "    nce(y_true, y_score),\n",
    "    results_ece[0],\n",
    "] + list(result_yc[:3])\n",
    "\n",
    "print(\n",
    "    f\"\"\"    AUC_ROC:\\t{results[0]:.5f}\n",
    "    AUC_PR:\\t{results[1]:.5f}\n",
    "    AUC_NT:\\t{results[2]:.5f}\n",
    "    NCE:\\t{results[3]:.5f}\n",
    "    ECE:\\t{results[4]:.5f}\n",
    "    AUC_YC:\\t{results[5]:.5f}\n",
    "    MAX_YC:\\t{results[7]:.5f}\n",
    "    STD_YC:\\t{results[6]:.5f}\n",
    "    \"\"\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9ab2b130",
   "metadata": {
    "id": "498e03d0"
   },
   "source": [
    "Note that despite the overall improvement, $\\mathrm{NCE}$ and $\\mathrm{ECE}$ have gotten worse. This is due to the class imbalance caused by the low WER: almost all recognized words are correct."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f96cea04",
   "metadata": {
    "id": "45856cba"
   },
   "source": [
    "Now, let's plot the $\\mathrm{ROC}$ curve as well as the confidence histograms of correctly and incorrectly recognized words."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81844713",
   "metadata": {
    "id": "ff049043"
   },
   "outputs": [],
   "source": [
    "from tempfile import TemporaryDirectory\n",
    "\n",
    "\n",
    "plot_dir = TemporaryDirectory()\n",
    "os.makedirs(plot_dir.name, exist_ok=True)\n",
    "\n",
    "mask_correct = y_true == 1\n",
    "y_score_correct = y_score[mask_correct]\n",
    "y_score_incorrect = y_score[~mask_correct]\n",
    "\n",
    "# histogram of the correct distribution\n",
    "save_confidence_hist(y_score_correct, plot_dir.name, \"hist_correct\")\n",
    "# histogram of the incorrect distribution\n",
    "save_confidence_hist(y_score_incorrect, plot_dir.name, \"hist_incorrect\")\n",
    "# AUC-ROC curve\n",
    "save_roc_curve(y_true, y_score, plot_dir.name, \"roc\")\n",
    "\n",
    "\n",
    "display(\n",
    "    Image(filename=os.path.join(plot_dir.name, \"hist_correct.png\"), retina=True),\n",
    "    Image(filename=os.path.join(plot_dir.name, \"hist_incorrect.png\"), retina=True),\n",
    "    Image(filename=os.path.join(plot_dir.name, \"roc.png\"), retina=True),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "841a27ca",
   "metadata": {},
   "source": [
    "Optionally, you can look at curves for other metrics ($\\mathrm{PR}$, $\\mathrm{NT}$, $\\mathrm{ECE}$, and $\\mathrm{YC}$)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6164e8f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# AUC-PR curve\n",
    "save_pr_curve(y_true, y_score, plot_dir.name, \"pr\")\n",
    "# AUC-NT curve\n",
    "save_nt_curve(y_true, y_score, plot_dir.name, \"nt\")\n",
    "# ECE curve\n",
    "ece_thresholds, ece_values = results_ece[-1]\n",
    "ece_values /= max(ece_values)\n",
    "save_custom_confidence_curve(\n",
    "    ece_thresholds, ece_values, plot_dir.name, \"ece\", \"Threshold\", \"|Accuracy − Confidence score|\"\n",
    ")\n",
    "# AUC-YC curve\n",
    "yc_thresholds, yc_values = result_yc[-1]\n",
    "save_custom_confidence_curve(\n",
    "    yc_thresholds, yc_values, plot_dir.name, \"yc\", \"Threshold\", \"True Positive Rate − False Positive Rate\"\n",
    ")\n",
    "\n",
    "\n",
    "display(\n",
    "    Image(filename=os.path.join(plot_dir.name, \"pr.png\"), retina=True),\n",
    "    Image(filename=os.path.join(plot_dir.name, \"nt.png\"), retina=True),\n",
    "    Image(filename=os.path.join(plot_dir.name, \"ece.png\"), retina=True),\n",
    "    Image(filename=os.path.join(plot_dir.name, \"yc.png\"), retina=True),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9f63a172",
   "metadata": {
    "id": "ad78630a"
   },
   "source": [
    "You can use `scripts/speech_recognition/confidence/benchmark_asr_confidence.py` to find optimal confidence hyperparameters."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1d9a822d",
   "metadata": {
    "id": "15e25521"
   },
   "source": [
    "# 4. Confidence applications"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8ab6e666",
   "metadata": {
    "id": "dbb82877"
   },
   "source": [
    "## 4.1. Small WER improvement\n",
    "\n",
    "Good confidence scores can slightly reduce WER when low-confidence words are removed from the recognition results.\n",
    "\n",
    "Consider the following example."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4038863c",
   "metadata": {
    "id": "02eb4e1f"
   },
   "source": [
    "Let's compare the detailed WER of the transcribed test set before and after removing low-confidence words."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "204d36ac",
   "metadata": {
    "id": "fdf790b5"
   },
   "outputs": [],
   "source": [
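    "def drop_low_confidence_words(text: str, word_confidence: List[float], threshold: float) -> str:\n",
    "    \"\"\"Remove the words whose confidence score is below `threshold`.\"\"\"\n",
    "    return \" \".join([word for word, conf in zip(text.split(), word_confidence) if conf >= threshold])\n",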
    "\n",
    "\n",
    "threshold = 0.001\n",
    "\n",
    "wer_initial = word_error_rate_detail([h.text for h in transcriptions], current_test_set.reference_texts)\n",
    "print(\n",
    "    f\"\"\"WER detail before removing low confidence words:\n",
    "    WER:\\t{wer_initial[0]:.5f}\n",
    "    INS_rate:\\t{wer_initial[2]:.5f}\n",
    "    DEL_rate:\\t{wer_initial[3]:.5f}\n",
    "    SUB_rate:\\t{wer_initial[4]:.5f}\"\"\"\n",
    ")\n",
    "\n",
    "wer_conf_dropped = word_error_rate_detail(\n",
    "    [drop_low_confidence_words(hyp.text, hyp.word_confidence, threshold) for hyp in transcriptions],\n",
    "    current_test_set.reference_texts,\n",
    ")\n",
    "print(\n",
    "    f\"\"\"WER detail after removing low confidence words:\n",
    "    WER:\\t{wer_conf_dropped[0]:.5f}\n",
    "    INS_rate:\\t{wer_conf_dropped[2]:.5f}\n",
    "    DEL_rate:\\t{wer_conf_dropped[3]:.5f}\n",
    "    SUB_rate:\\t{wer_conf_dropped[4]:.5f}\"\"\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4f153cdd",
   "metadata": {
    "id": "28ac85b1"
   },
   "source": [
    "You can see that the right (in this example, extremely low) `threshold` can reduce WER slightly, reducing insertions and substitutions at the cost of more deletions.\n",
    "\n",
    "Now let's see how to find the optimal threshold.\n",
    "\n",
    "The most common method for automatically determining the cutoff threshold is to take the value $\\tau$ that maximizes the unnormalized Youden's curve $\\mathrm{TN}(\\tau) - \\mathrm{FN}(\\tau)$, i.e. the number of incorrect words removed minus the number of correct words removed. This balances removing as many incorrect words as possible against sacrificing as few correct words as possible.\n",
    "\n",
    "However, the unnormalized $\\mathrm{MAX}_\\mathrm{YC}$ method does not work well for WER reduction. A likely reason is that WER benefits only from removing inserted words: removing a substituted word roughly turns a substitution into a deletion, and removing a correct word adds a new deletion. Let's compare this method with explicitly minimizing WER with respect to the threshold."
   ]
  },
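  {
   "cell_type": "markdown",
   "id": "e1a7c3f0",
   "metadata": {},
   "source": [
    "To see why maximizing the unnormalized Youden's curve does not necessarily minimize WER, here is a minimal, self-contained sketch on a made-up sentence. It does not use NeMo: `levenshtein`, `simple_wer`, and `drop_below` are hypothetical helpers written only for this illustration, and the confidence scores are invented. Since WER is a word-level edit distance, removing an inserted word removes an error, removing a substituted word roughly turns a substitution into a deletion, and removing a correct word adds a deletion. With `thr=0.25`, only the substituted word `bat` is removed: the unnormalized Youden's value grows by one, yet WER does not change. WER drops only once the inserted word `down` is removed as well."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1a7c3f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def levenshtein(ref_words, hyp_words):\n",
    "    # word-level edit distance: substitutions, deletions and insertions all cost 1\n",
    "    prev = list(range(len(hyp_words) + 1))\n",
    "    for i, r in enumerate(ref_words, start=1):\n",
    "        cur = [i] + [0] * len(hyp_words)\n",
    "        for j, h in enumerate(hyp_words, start=1):\n",
    "            cur[j] = min(\n",
    "                prev[j] + 1,  # deletion of a reference word\n",
    "                cur[j - 1] + 1,  # insertion of a hypothesis word\n",
    "                prev[j - 1] + (r != h),  # substitution, or a match if the words are equal\n",
    "            )\n",
    "        prev = cur\n",
    "    return prev[-1]\n",
    "\n",
    "\n",
    "def simple_wer(ref, hyp):\n",
    "    ref_words = ref.split()\n",
    "    return levenshtein(ref_words, hyp.split()) / len(ref_words)\n",
    "\n",
    "\n",
    "def drop_below(words, scores, thr):\n",
    "    # keep only the words whose confidence is at least `thr`\n",
    "    return ' '.join(w for w, s in zip(words, scores) if s >= thr)\n",
    "\n",
    "\n",
    "toy_ref = 'the cat sat on the mat'\n",
    "# 'bat' is a substitution and 'down' is an insertion; the scores are invented\n",
    "toy_words = ['the', 'bat', 'sat', 'on', 'the', 'mat', 'down']\n",
    "toy_scores = [0.9, 0.2, 0.9, 0.9, 0.9, 0.9, 0.3]\n",
    "\n",
    "for thr in (0.0, 0.25, 0.5):\n",
    "    toy_hyp = drop_below(toy_words, toy_scores, thr)\n",
    "    print(f'thr={thr:.2f}  WER={simple_wer(toy_ref, toy_hyp):.3f}  hyp={toy_hyp!r}')"
   ]
  },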
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19147b4a",
   "metadata": {
    "id": "9b81e449"
   },
   "outputs": [],
   "source": [
    "from joblib import Parallel, delayed\n",
    "from multiprocessing import cpu_count\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "def max_unnormalized_yc(\n",
    "    y_true: Union[List[int], np.ndarray],\n",
    "    y_score: Union[List[float], np.ndarray],\n",
    "    n_bins: int = 100,\n",
    "    start: float = 0.0,\n",
    "    stop: float = 1.0,\n",
    ") -> float:\n",
    "    \"\"\"Return the threshold at the maximum of the unnormalized Youden's curve,\n",
    "    i.e. the threshold that maximizes (incorrect words removed) - (correct words removed).\n",
    "    \"\"\"\n",
    "    y_true = np.array(y_true)\n",
    "    y_score = np.array(y_score)\n",
    "    thresholds = np.linspace(start, stop, n_bins + 1)\n",
    "    assert len(y_true) == len(y_score)\n",
    "    assert np.all(y_true >= 0) and np.all(y_true <= 1)\n",
    "    if np.all(y_true == 0) or np.all(y_true == 1):\n",
    "        # degenerate case with a single class: do not remove anything\n",
    "        return 0.0\n",
    "    mask_correct = y_true == 1\n",
    "    y_score_correct = y_score[mask_correct]\n",
    "    y_score_incorrect = y_score[~mask_correct]\n",
    "    unnormalized_yc = []\n",
    "    for threshold in thresholds:\n",
    "        tn = len((y_score_incorrect < threshold).nonzero()[0])  # incorrect words removed\n",
    "        fn = len((y_score_correct < threshold).nonzero()[0])  # correct words removed\n",
    "        unnormalized_yc.append((threshold, tn - fn))\n",
    "    return max(unnormalized_yc, key=lambda x: x[1])[0]\n",
    "\n",
    "\n",
    "def min_wer(ref: List[str], transcriptions, n_bins: int = 100, start: float = 0.0, stop: float = 1.0):\n",
    "    \"\"\"Find the threshold value that delivers the minimum WER.\n",
    "\n",
    "    Returns:\n",
    "        A (threshold, WER) tuple.\n",
    "    \"\"\"\n",
    "    thresholds = np.linspace(start, stop, n_bins + 1)\n",
    "    hyps = [(hyp.text, hyp.word_confidence) for hyp in transcriptions]\n",
    "\n",
    "    def _get_wer(threshold, hyps, ref):\n",
    "        filtered = [drop_low_confidence_words(text, conf, threshold) for text, conf in hyps]\n",
    "        return threshold, word_error_rate_detail(filtered, ref)[0]\n",
    "\n",
    "    # evaluate all thresholds in parallel\n",
    "    wers = Parallel(n_jobs=cpu_count())(delayed(_get_wer)(threshold, hyps, ref) for threshold in tqdm(thresholds))\n",
    "    return min(wers, key=lambda x: x[1])\n",
    "\n",
    "\n",
    "targets_with_confidence = [get_word_targets_with_confidence(tran) for tran in transcriptions]\n",
    "correct_marks = [\n",
    "    get_correct_marks(r.split(), h.words) for r, h in zip(current_test_set.reference_texts, transcriptions)\n",
    "]\n",
    "y_true, y_score = np.array(\n",
    "    [[f, p[1]] for cm, twc in zip(correct_marks, targets_with_confidence) for f, p in zip(cm, twc)]\n",
    ").T\n",
    "\n",
    "threshold_yc = max_unnormalized_yc(y_true, y_score)\n",
    "yc_wer_value = word_error_rate(\n",
    "    [drop_low_confidence_words(hyp.text, hyp.word_confidence, threshold_yc) for hyp in transcriptions],\n",
    "    current_test_set.reference_texts,\n",
    ")\n",
    "threshold_min_wer, min_wer_value = min_wer(current_test_set.reference_texts, transcriptions, stop=0.1)\n",
    "\n",
    "print(\n",
    "    f\"\"\"    Initial WER: {wer_initial[0]:.5f}\n",
    "    Optimal threshold and WER based on the Youden's curve: {threshold_yc}, {yc_wer_value:.5f}\n",
    "    Optimal threshold for the minimum WER: {threshold_min_wer}, {min_wer_value:.5f}\n",
    "    \"\"\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "425d010e",
   "metadata": {
    "id": "3b278d2d"
   },
   "source": [
    "As you can see, the threshold at the maximum of the unnormalized Youden's curve makes WER significantly worse, while the threshold that minimizes WER is close to zero.\n",
    "\n",
    "Let's use a different confidence estimation setup to see if we can improve WER at least a bit further."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d080686",
   "metadata": {
    "id": "39f72c78"
   },
   "outputs": [],
   "source": [
    "confidence_cfg = ConfidenceConfig(\n",
    "    preserve_word_confidence=True,\n",
    "    preserve_token_confidence=True,\n",
    "    aggregation=\"min\",\n",
    "    method_cfg=DictConfig({\"entropy_type\": \"tsallis\", \"alpha\": 1.5, \"entropy_norm\": \"lin\"}),\n",
    ")\n",
    "\n",
    "model.change_decoding_strategy(\n",
    "    RNNTDecodingConfig(fused_batch_size=-1, strategy=\"greedy_batch\", confidence_cfg=confidence_cfg)\n",
    "    if is_rnnt\n",
    "    else CTCDecodingConfig(confidence_cfg=confidence_cfg)\n",
    ")\n",
    "\n",
    "transcriptions = model.transcribe(audio=current_test_set.filepaths, batch_size=16, return_hypotheses=True, num_workers=4)\n",
    "if is_rnnt:\n",
    "    transcriptions = transcriptions[0]\n",
    "\n",
    "threshold_min_wer, min_wer_value = min_wer(current_test_set.reference_texts, transcriptions)\n",
    "\n",
    "print(\n",
    "    f\"\"\"    Initial WER: {wer_initial[0]:.5f}\n",
    "    Optimal threshold for the minimum WER: {threshold_min_wer}, {min_wer_value:.5f}\n",
    "    \"\"\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e3c9cc02",
   "metadata": {
    "id": "e00581b1"
   },
   "source": [
    "Overall, such a small WER improvement has little practical value. However, it suggests that more accurate confidence estimation methods could reduce WER further."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "694d1752",
   "metadata": {
    "id": "f9f89665"
   },
   "source": [
    "## 4.2. Reducing hallucinations with confidence scores\n",
    "\n",
    "One common application of confidence scores is the removal of recognition hallucinations.\n",
    "\n",
    "Let's see how this can be done."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "98a1ef83",
   "metadata": {
    "id": "c1c28379"
   },
   "source": [
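    "First, let's create a dataset on which the ASR model is likely to hallucinate.\n",
    "\n",
    "Here we build it from the librosa example recordings: each recording is reversed in time, and every pair of reversed recordings is summed (the shorter one is repeated cyclically to match the length of the longer one)."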
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f12a5041",
   "metadata": {
    "id": "3b0a0b4c"
   },
   "outputs": [],
   "source": [
    "from itertools import combinations\n",
    "import json\n",
    "import librosa\n",
    "import soundfile as sf\n",
    "\n",
    "def cyclic_sum(x, y):\n",
    "    if x.shape[0] < y.shape[0]:\n",
    "        x, y = y, x\n",
    "    if x.shape[0] > y.shape[0]:\n",
    "        y = np.take(y, range(0, x.shape[0]), mode='wrap')\n",
    "    return x + y\n",
    "\n",
    "def generate_noise_examples(example_list: List[str], save_dir: str, samplerate: int = 16000):\n",
    "    \"\"\"Generate noise examples from librosa example recordings.\n",
    "    It loads the selected examples, reverses them in time, and mixes every pair of them with `cyclic_sum`.\n",
    "\n",
    "    Returns:\n",
    "        A manifest with the noise wavs.\n",
    "    \"\"\"\n",
    "    samples = {ex: librosa.core.load(librosa.util.example(key=ex, hq=True), sr=samplerate)[0] \n",
    "               for ex in example_list}\n",
    "    noise_samples = {\"_\".join([left, right]): cyclic_sum(samples[left][::-1], samples[right][::-1]) \n",
    "                     for left, right in combinations(samples.keys(), 2)}\n",
    "\n",
    "    os.makedirs(save_dir, exist_ok=True)\n",
    "    manifest = os.path.join(save_dir, \"manifest.json\")\n",
    "    with open(manifest, \"tw\", encoding=\"utf-8\") as fout:\n",
    "        for k, v in noise_samples.items():\n",
    "            audio_path = os.path.join(save_dir, f\"{k}.wav\")\n",
    "            sf.write(audio_path, v, samplerate=samplerate)\n",
    "            metadata = {\n",
    "                \"audio_filepath\": audio_path,\n",
    "                \"duration\": librosa.core.get_duration(y=v, sr=samplerate),\n",
    "                \"label\": \"noise\",\n",
    "                \"text\": \"_\"\n",
    "            }\n",
    "            json.dump(metadata, fout)\n",
    "            fout.write('\\n')\n",
    "\n",
    "    return manifest\n",
    "\n",
    "librosa_list_examples = ['brahms',\n",
    "                         'choice',\n",
    "                         'fishin',\n",
    "                         'humpback',\n",
    "                         'libri1',\n",
    "                         'libri2',\n",
    "                         'libri3',\n",
    "                         'nutcracker',\n",
    "                         'pistachio',\n",
    "                         'robin',\n",
    "                         'sweetwaltz',\n",
    "                         'trumpet',\n",
    "                         'vibeace']\n",
    "sr = 16000\n",
    "\n",
    "noise_dir = os.path.join(DATA_DIR, \"noise\")\n",
    "noise_manifest = generate_noise_examples(librosa_list_examples, noise_dir, sr)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f28da61f",
   "metadata": {},
   "source": [
    "The original examples contain speech, music, or noise. The resulting audio recordings are considered to contain no recognizable speech.\n",
    "\n",
    "You can listen to one of the generated recordings."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b4e7007",
   "metadata": {},
   "outputs": [],
   "source": [
    "noise_data = load_data(noise_manifest)\n",
    "\n",
    "display(Audio(noise_data.filepaths[0]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1db80ae4",
   "metadata": {
    "id": "f7f9ddca"
   },
   "source": [
    "Now let's transcribe the new data with the default confidence estimator."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a872926",
   "metadata": {
    "id": "60f39094"
   },
   "outputs": [],
   "source": [
    "confidence_cfg = ConfidenceConfig(\n",
    "    preserve_word_confidence=True,\n",
    "    preserve_token_confidence=True,\n",
    ")\n",
    "\n",
    "model.change_decoding_strategy(\n",
    "    RNNTDecodingConfig(fused_batch_size=-1, strategy=\"greedy_batch\", confidence_cfg=confidence_cfg)\n",
    "    if is_rnnt\n",
    "    else CTCDecodingConfig(confidence_cfg=confidence_cfg)\n",
    ")\n",
    "\n",
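    "noise_transcriptions = model.transcribe(\n",
    "    audio=noise_data.filepaths, batch_size=4, return_hypotheses=True, num_workers=4\n",
    ")\n",
    "if is_rnnt:\n",
    "    noise_transcriptions = noise_transcriptions[0]\n",
    "\n",
    "# re-transcribe the speech test set with the same confidence estimator,\n",
    "# so that the HDR threshold is tuned on comparable confidence scores\n",
    "transcriptions = model.transcribe(audio=current_test_set.filepaths, batch_size=16, return_hypotheses=True, num_workers=4)\n",
    "if is_rnnt:\n",
    "    transcriptions = transcriptions[0]"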
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3d097ca6",
   "metadata": {
    "id": "2f192186"
   },
   "source": [
    "On a dataset without any speech, every recognized word is a hallucination, so hallucinations can be measured as the Word Insertions per Second (WIS) value."
   ]
  },
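  {
   "cell_type": "markdown",
   "id": "e1a7c3f2",
   "metadata": {},
   "source": [
    "A minimal sketch of the WIS definition on made-up transcripts (`toy_wis` is a hypothetical helper that mirrors the per-utterance averaging of `word_insertions_per_second` below). Words are counted with `str.split()` without an argument: `''.split(' ')` returns `['']`, so with `split(' ')` an empty transcript would be counted as one inserted word."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1a7c3f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def toy_wis(texts, durations):\n",
    "    # every hypothesis word is an insertion because the reference is empty;\n",
    "    # rates are computed per utterance and then averaged over utterances\n",
    "    rates = [len(text.split()) / dur for text, dur in zip(texts, durations)]\n",
    "    return sum(rates) / len(rates)\n",
    "\n",
    "\n",
    "toy_texts = ['', 'oh yeah', 'the the the']  # made-up hypotheses for non-speech audio\n",
    "toy_durations = [2.0, 4.0, 3.0]  # seconds\n",
    "print(toy_wis(toy_texts, toy_durations))  # (0/2 + 2/4 + 3/3) / 3\n",
    "# str.split() returns [] for an empty string, whereas split(' ') returns ['']\n",
    "print(len(''.split()), len(''.split(' ')))"
   ]
  },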
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19c6321c",
   "metadata": {
    "id": "3589da00"
   },
   "outputs": [],
   "source": [
    "def word_insertions_per_second(texts: List[str], durations: List[float]):\n",
    "    \"\"\"Calculate the Word Insertions per Second (WIS) value for the given recognition results \n",
    "    and their corresponding audio duration.\n",
    "    \"\"\"\n",
    "    assert len(texts) == len(durations)\n",
    "\n",
    "    wis_per_utt = [len(text.split()) / duration for text, duration in zip(texts, durations)]\n",
    "    return sum(wis_per_utt) / len(wis_per_utt), wis_per_utt\n",
    "\n",
    "wis, wis_per_utt = word_insertions_per_second([t.text for t in noise_transcriptions], noise_data.durations)\n",
    "print(f\"Original Word Insertions per Second: {wis:.5f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bcf44daf",
   "metadata": {
    "id": "a0d8135d"
   },
   "source": [
    "Now, the ability of a confidence estimator to detect hallucinations is measured by the Hallucination Detection Rate (HDR).\n",
    "\n",
    "It shows what fraction of hallucinated words can be removed, provided that no more than a fixed fraction of correct words is erroneously removed under normal recognition conditions.\n",
    "\n",
    "HDR is another name for the metric $\\mathrm{TNR}_{\\mathrm{FNR}=e}$, which is calculated as $\\mathrm{TNR}(Y,\\tau)$ at the largest threshold $\\tau$ such that $\\mathrm{FNR}(X,\\tau) \\le e$, where $X$ is the dataset with supervision (used to tune $\\tau$) and $Y$ is the noise-only dataset. A typical value of $e$ is 0.05.\n",
    "\n",
    "Let's compute HDR and the new WIS.\n",
    "\n",
    "The generated dataset is clearly distinct from speech, so $e=0.01$ is sufficient."
   ]
  },
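  {
   "cell_type": "markdown",
   "id": "e1a7c3f4",
   "metadata": {},
   "source": [
    "The HDR definition above can be sketched with NumPy on synthetic scores. This is an illustration only: `toy_hdr` is a hypothetical helper, not part of NeMo, and the scores are made up. It takes the largest grid threshold $\\tau$ with $\\mathrm{FNR}(X,\\tau) \\le e$ and returns the fraction of noise words scored below it. Since FNR never decreases as $\\tau$ grows, this yields the same threshold as the early-exit loop in the `hdr` function below. The toy example uses $e=0.1$ because it contains only ten correct words."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1a7c3f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def toy_hdr(correct_scores, noise_scores, e=0.05, n_bins=100):\n",
    "    thresholds = np.linspace(0.0, 1.0, n_bins + 1)\n",
    "    # FNR(X, tau): fraction of correct speech words that would be removed (score < tau)\n",
    "    fnr = np.array([(correct_scores < t).mean() for t in thresholds])\n",
    "    # largest grid threshold whose FNR stays within e (FNR at tau = 0 is always 0)\n",
    "    tau = thresholds[fnr <= e].max()\n",
    "    # TNR(Y, tau): fraction of noise (hallucinated) words that would be removed\n",
    "    tnr = (noise_scores < tau).mean()\n",
    "    return tnr, tau\n",
    "\n",
    "\n",
    "# made-up word confidence scores: correct words on speech, hallucinated words on noise\n",
    "toy_correct_scores = np.array([0.325, 0.585, 0.71, 0.83, 0.88, 0.91, 0.94, 0.96, 0.98, 0.99])\n",
    "toy_noise_scores = np.array([0.05, 0.15, 0.25, 0.45, 0.65, 0.75])\n",
    "\n",
    "toy_tnr, toy_tau = toy_hdr(toy_correct_scores, toy_noise_scores, e=0.1)\n",
    "print(f'tau={toy_tau:.2f}  HDR={toy_tnr:.3f}')"
   ]
  },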
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3dac1f7d",
   "metadata": {
    "id": "0612ccf6"
   },
   "outputs": [],
   "source": [
    "def hdr(\n",
    "    y_true_speech: Union[List[int], np.ndarray],\n",
    "    y_score_speech: Union[List[float], np.ndarray],\n",
    "    y_score_noise: Union[List[float], np.ndarray],\n",
    "    max_fnr: float = 0.05,\n",
    "    n_bins: int = 100,\n",
    ") -> Tuple[float, float]:\n",
    "    \"\"\"Compute Hallucination Detection Rate (HDR) from prediction scores.\n",
    "\n",
    "    Returns:\n",
    "        tnr: True Negative Rate on the noise data, i.e. HDR\n",
    "        threshold_hdr: Optimal threshold\n",
    "    \"\"\"\n",
    "    y_true_speech = np.array(y_true_speech)\n",
    "    y_score_speech = np.array(y_score_speech)\n",
    "    y_score_noise = np.array(y_score_noise)\n",
    "    thresholds = np.linspace(0, 1, n_bins + 1)\n",
    "    assert y_true_speech.shape[0] == y_score_speech.shape[0]\n",
    "    assert np.all(y_true_speech >= 0) and np.all(y_true_speech <= 1)\n",
    "    if np.all(y_true_speech == 0) or np.all(y_true_speech == 1):\n",
    "        return 0.0, 0.0\n",
    "    mask_correct = y_true_speech == 1\n",
    "    count_correct = max(mask_correct.nonzero()[0].shape[0], 1)\n",
    "    y_score_correct = y_score_speech[mask_correct]\n",
    "    threshold_hdr = 0.0\n",
    "    for threshold in thresholds:\n",
    "        fnr = (y_score_correct < threshold).nonzero()[0].shape[0] / count_correct\n",
    "        if fnr <= max_fnr:\n",
    "            threshold_hdr = threshold\n",
    "        else:\n",
    "            break\n",
    "    tnr = (y_score_noise < threshold_hdr).nonzero()[0].shape[0] / y_score_noise.shape[0]\n",
    "    return tnr, threshold_hdr\n",
    "\n",
    "\n",
    "# e\n",
    "max_fnr = 0.01\n",
    "\n",
    "correct_marks = [\n",
    "    mark for r, h in zip(current_test_set.reference_texts, transcriptions) for mark in get_correct_marks(r.split(), h.words)\n",
    "]\n",
    "y_score_speech = [w for h in transcriptions for w in h.word_confidence]\n",
    "y_score_noise = [w for h in noise_transcriptions for w in h.word_confidence]\n",
    "hdr_score, threshold_hdr = hdr(correct_marks, y_score_speech, y_score_noise, max_fnr=max_fnr)\n",
    "# approximate new WIS: assumes the same fraction of words is removed from every utterance\n",
    "wis_new = wis - wis * hdr_score\n",
    "\n",
    "print(\n",
    "    f\"\"\"    Hallucination Detection Rate for max_fnr={max_fnr} : {hdr_score:.5f}\n",
    "    New Word Insertions Per Second: {wis_new:.5f}\"\"\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "443938bc",
   "metadata": {
    "id": "418297d6"
   },
   "source": [
    "Finally, let's display the noise utterances with the highest WIS to see which hallucinations persist after thresholding."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dde9e7db",
   "metadata": {
    "id": "3815e8e3"
   },
   "outputs": [],
   "source": [
    "sorted_wis_indices = sorted(enumerate(wis_per_utt), key=lambda x: x[1])[::-1]\n",
    "\n",
    "show_dataset_with_confidence(\n",
    "    indices=sorted_wis_indices,\n",
    "    transcriptions=noise_transcriptions,\n",
    "    test_set=noise_data,\n",
    "    threshold=threshold_hdr,\n",
    "    filepaths=noise_data.filepaths,\n",
    "    html_show=is_colab,\n",
    "    min_dur_to_show=0.0,\n",
    "    utt_to_show=5,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "66f92938",
   "metadata": {
    "id": "0ac58ef2"
   },
   "source": [
    "# Summary\n",
    "This tutorial covered the basics of ASR confidence estimation and two applications of ASR word confidence: WER reduction and hallucination removal.\n",
    "\n",
    "See the [ASR Confidence-based Ensembles](https://github.com/NVIDIA/NeMo/blob/main/tutorials/asr/Confidence_Ensembles.ipynb) tutorial for another important application of ASR confidence estimation."
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "T4",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
